# Predict unexpressed genes (labeled as 0) vs. expressed genes (labeled as 1)
# Predictors used: 1500bp promoters, 1500bp terminators or both 

import pandas as pd
import Bio.SeqIO
import numpy as np
import math
import random
import pickle
import os
from sys import argv
from keras.metrics import mean_squared_error
from keras.layers import Dense,Activation,Flatten,Dropout,Conv2D, MaxPooling2D
from keras.optimizers import Adam
from keras.models import Sequential,Model
from keras.utils import np_utils
from keras.callbacks import EarlyStopping
from datetime import datetime
from keras import backend as K
from sklearn.metrics import confusion_matrix
from deeplift.dinuc_shuffle import dinuc_shuffle

# Input files: fixed-length promoter/terminator FASTA plus the expression table.
SEQ_FILE_PRO = 'promoters.fa'
SEQ_FILE_TER = 'terminators.fa'
DATA_FILE = 'un_vs_expr.csv'
RESULT_DIRECTORY = 'trained_models_BEM/'

# Make sure the output directory exists before any model is saved.
os.makedirs(RESULT_DIRECTORY, exist_ok=True)

# Expression table, indexed by gene identifier.
data = pd.read_csv(DATA_FILE, sep='\t').set_index('Geneid')

# Accumulators filled while parsing the FASTA files.
geneIDs = []
categories = []

# Raw sequences: promoter+terminator, promoter only, terminator only.
seqs, seqs_pro, seqs_ter = [], [], []

# Dinucleotide-shuffled counterparts (preserve dinucleotide frequencies).
seqs_dshuffle, seqs_pro_dshuffle, seqs_ter_dshuffle = [], [], []

# Single-nucleotide-shuffled counterparts (destroy all sequence structure).
seqs_sshuffle, seqs_pro_sshuffle, seqs_ter_sshuffle = [], [], []

# Read promoter and terminator sequences in lockstep (the two FASTA files are
# assumed to list the same genes in the same order -- TODO confirm) and build
# the shuffled control sequences at the same time.
for x, y in zip(Bio.SeqIO.parse(SEQ_FILE_PRO, 'fasta'), Bio.SeqIO.parse(SEQ_FILE_TER, 'fasta')):
    geneID = x.id
    seq_pro = str(x.seq)
    # Single-nucleotide shuffle of the promoter.
    seq_pro_sshuffle = list(seq_pro)
    random.shuffle(seq_pro_sshuffle)
    seq_pro_sshuffle = ''.join(seq_pro_sshuffle)
    # Dinucleotide shuffle of the promoter.
    seq_pro_dshuffled = dinuc_shuffle(seq_pro)
    seq_ter = str(y.seq)
    # Single-nucleotide shuffle of the terminator.
    seq_ter_sshuffle = list(seq_ter)
    random.shuffle(seq_ter_sshuffle)
    seq_ter_sshuffle = ''.join(seq_ter_sshuffle)
    # Dinucleotide shuffle of the terminator.
    seq_ter_dshuffled = dinuc_shuffle(seq_ter)
    # Combined promoter+terminator variants.
    seq = seq_pro + seq_ter
    seq_sshuffle = seq_pro_sshuffle + seq_ter_sshuffle
    seq_dshuffle = seq_pro_dshuffled + seq_ter_dshuffled

    # Keep only genes present in the expression table.
    if geneID in data.index:
        seqs.append(seq)
        seqs_sshuffle.append(seq_sshuffle)
        # BUG FIX: the dinucleotide-shuffled sequences were generated above but
        # never stored, leaving seqs_dshuffle / seqs_pro_dshuffle /
        # seqs_ter_dshuffle empty and crashing the D_Shuffle training runs
        # (indexing an empty array with training_indices).
        seqs_dshuffle.append(seq_dshuffle)
        seqs_pro.append(seq_pro)
        seqs_pro_sshuffle.append(seq_pro_sshuffle)
        seqs_pro_dshuffle.append(seq_pro_dshuffled)
        seqs_ter.append(seq_ter)
        seqs_ter_sshuffle.append(seq_ter_sshuffle)
        seqs_ter_dshuffle.append(seq_ter_dshuffled)
        geneIDs.append(geneID)
        category = data['category'][geneID]
        # NOTE(review): the file header says unexpressed=0 / expressed=1, but
        # here 'expressed' is encoded as 0 and everything else as 1.  All
        # downstream code is consistent with this encoding, so it is kept --
        # confirm which labeling the header intended.
        if category == 'expressed':
            categories.append(0)
        else:
            categories.append(1)

# One-hot encoding lookup: each nucleotide maps to a 4-element indicator
# column in A, C, G, T row order.  (Renamed from `dict`, which shadowed the
# builtin.)
NT_TO_ONE_HOT = {'A': [1, 0, 0, 0], 'C': [0, 1, 0, 0], 'G': [0, 0, 1, 0], 'T': [0, 0, 0, 1]}
# Alternative lookup for methylation-site encoding (kept for reference):
# NT_TO_ONE_HOT = {'X':[1,0,0,0],'Y':[0,1,0,0],'Z':[0,0,1,0],'C':[0,0,0,1],'N':[0,0,0,0],'A':[0,0,0,0],'T':[0,0,0,0],'G':[0,0,0,0]}

# Encoding for characters outside the lookup table.
_ZERO_COLUMN = [0, 0, 0, 0]

def one_hot_encoding(seq):
    """Return a (4, len(seq)) one-hot matrix for a nucleotide string.

    Rows correspond to A, C, G, T (in that order).  Characters not in the
    lookup table (e.g. ambiguous bases such as 'N') encode as an all-zero
    column instead of raising KeyError, matching the convention of the
    commented-out methylation table above.
    """
    one_hot_encoded = np.zeros(shape=(4, len(seq)))
    for i, nt in enumerate(seq):
        one_hot_encoded[:, i] = NT_TO_ONE_HOT.get(nt, _ZERO_COLUMN)
    return one_hot_encoded

def _encode_batch(sequence_list):
    # Stack per-sequence one-hot matrices into an (N, 4, L, 1) float32 array,
    # the channel-last layout Conv2D expects.
    return np.expand_dims(
        np.array([one_hot_encoding(s) for s in sequence_list], dtype=np.float32), 3)

one_hot_seqs = _encode_batch(seqs)
one_hot_seqs_pro = _encode_batch(seqs_pro)
one_hot_seqs_ter = _encode_batch(seqs_ter)

# Zero out three-base windows at fixed positions in the real sequences --
# presumably the start/stop codons, so the network cannot key on them
# directly; the indices assume fixed promoter/terminator lengths (TODO
# confirm against the FASTA extraction).
one_hot_seqs[:, :, 1000:1003, :] = 0
one_hot_seqs[:, :, 1997:2000, :] = 0
one_hot_seqs_pro[:, :, 1000:1003, :] = 0
one_hot_seqs_ter[:, :, 497:500, :] = 0

# Shuffled controls are encoded unmasked (shuffling already destroys any
# fixed-position signal).
one_hot_seqs_dshuffle = _encode_batch(seqs_dshuffle)
one_hot_seqs_pro_dshuffle = _encode_batch(seqs_pro_dshuffle)
one_hot_seqs_ter_dshuffle = _encode_batch(seqs_ter_dshuffle)

one_hot_seqs_sshuffle = _encode_batch(seqs_sshuffle)
one_hot_seqs_pro_sshuffle = _encode_batch(seqs_pro_sshuffle)
one_hot_seqs_ter_sshuffle = _encode_batch(seqs_ter_sshuffle)

geneIDs = np.array(geneIDs)
# Two-column one-hot labels for the 2-way softmax output.
categories = np_utils.to_categorical(np.array(categories), 2)

# architecture of neural network
def build(window_size):
    """Build the CNN classifier for (4, window_size, 1) one-hot inputs.

    Three convolution/pooling stages followed by two dense layers and a
    2-way softmax output.
    """
    model = Sequential()

    # Stage 1: the first convolution spans all 4 nucleotide rows, collapsing
    # the one-hot axis; the rest of the network works along the sequence axis.
    model.add(Conv2D(64, kernel_size=(4, 8), padding='valid',
                     input_shape=[4, window_size, 1]))
    model.add(Activation('relu'))
    model.add(Conv2D(64, kernel_size=(1, 8), padding='same'))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(1, 8), strides=(1, 8), padding='same'))
    model.add(Dropout(0.25))

    # Stages 2 and 3: paired 1x8 convolutions, 8x pooling, dropout.
    for n_filters in (128, 64):
        for _ in range(2):
            model.add(Conv2D(n_filters, kernel_size=(1, 8), padding='same'))
            model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(1, 8), strides=(1, 8), padding='same'))
        model.add(Dropout(0.25))

    # Classifier head.
    model.add(Flatten())
    model.add(Dense(128))
    model.add(Activation('relu'))
    model.add(Dropout(0.25))
    model.add(Dense(64))
    model.add(Activation('relu'))
    model.add(Dense(2))
    model.add(Activation('softmax'))
    return model
# train models and save results
def train_models(TRAINING_ONE_HOT_SEQS,TESTING_ONE_HOT_SEQS,training_indices,testing_indices,split,test_subsample,shuffle,predictor):
    """Train one CNN on the training rows, evaluate on the test rows, and
    persist the model, predictions, and a one-line summary.

    TRAINING_ONE_HOT_SEQS / TESTING_ONE_HOT_SEQS: (N, 4, L, 1) arrays; they
    may differ so that shuffled sequences are used for training while real
    sequences are used for testing.
    split / test_subsample / shuffle / predictor: run labels written to the
    summary file.
    Reads the module-level `categories`, `geneIDs`, and RESULT_DIRECTORY.
    """
    # Select the rows of this cross-validation fold.
    TRAINING_ONE_HOT_SEQS = TRAINING_ONE_HOT_SEQS[training_indices]
    TESTING_ONE_HOT_SEQS = TESTING_ONE_HOT_SEQS[testing_indices]
    TRAINING_CATEGORIES = categories[training_indices]
    TESTING_CATEGORIES = categories[testing_indices]
    TESTING_GENES = geneIDs[testing_indices]

    LENGTH = TRAINING_ONE_HOT_SEQS.shape[2]
    # Stop early once validation loss stops improving for 3 epochs.
    callbacks = [EarlyStopping(monitor='val_loss', patience=3, verbose=0)]
    model = build(LENGTH)
    model.compile(loss='binary_crossentropy', optimizer='Adam', metrics=['accuracy'])
    model.fit(x=TRAINING_ONE_HOT_SEQS, y=TRAINING_CATEGORIES, batch_size=256,
              epochs=40, validation_data=(TESTING_ONE_HOT_SEQS, TESTING_CATEGORIES),
              callbacks=callbacks)

    prediction = model.predict(TESTING_ONE_HOT_SEQS)
    predicted_categories = np.argmax(prediction, axis=1)
    real_categories = np.argmax(TESTING_CATEGORIES, axis=1)
    # BUG FIX: pass labels explicitly so the matrix is always 2x2 and the
    # 4-way unpack cannot fail when the model predicts only one class.
    tn, fp, fn, tp = confusion_matrix(real_categories, predicted_categories,
                                      labels=[0, 1]).ravel()
    test_total = tp + tn + fp + fn
    accuracy = (tp + tn) / (test_total * 1.0)

    # Persist: timestamped model, pickled predictions, one summary row.
    now = datetime.now().strftime('%Y%m%d%H%M%S')
    model.save(RESULT_DIRECTORY + "MODEL_" + now)
    # BUG FIX: close the pickle file deterministically (was a bare open()).
    with open(RESULT_DIRECTORY + "PICKLE_" + now, 'wb') as pickle_file:
        pickle.dump([TESTING_GENES, TESTING_CATEGORIES, prediction], pickle_file)
    with open(RESULT_DIRECTORY + 'SUMMARY_FILE', "a") as result_file:
        result_file.write("\t".join([now, split, test_subsample, shuffle,
                                     str(tp), str(tn), str(fp), str(fn),
                                     str(accuracy), str(test_total), predictor]))
        result_file.write("\n")
    # Free the model and backend graph memory before the next run.
    del model
    K.clear_session()

# family-aware 10 times 5-fold cross-validation to evaluate the performance of the model 
# 3 types of input sequences: PROMOTER, TERMINATOR, PROMOTER+TERMINATOR
# 3 types of shuffling of input sequences: None, dinucleotide, single-nucleotide
# sequence shuffling performed only on training set, not testing set
# train 10*5*3*3=450 models in total

# The last 10 columns of the data table are assumed to hold the family-aware
# split assignments, one column per repetition -- TODO confirm column layout.
splits=data.columns.values[-10:]
subsamples=['subsample1','subsample2','subsample3','subsample4','subsample5']

for split in splits:
   for test_subsample in subsamples:
       # Test fold: genes assigned to this subsample for which we have sequences.
       testing_genes = data[data[split]==test_subsample].index
       testing_genes = np.array(list(set(testing_genes).intersection(set(geneIDs))))
       # In testing set, downsample expressed genes to make them balanced with unexpressed genes
       # NOTE(review): assumes expressed genes outnumber unexpressed ones in
       # every fold; np.random.choice(replace=False) raises otherwise.
       unexpressed_testing_genes=[gene for gene in testing_genes if data['category'][gene]=='unexpressed'  ]
       expressed_testing_genes=[gene for gene in testing_genes if data['category'][gene]=='expressed']
       expressed_testing_genes=np.random.choice(expressed_testing_genes,len(unexpressed_testing_genes),replace=False)
       testing_genes=np.concatenate((expressed_testing_genes,unexpressed_testing_genes),axis=0)

       # Training fold: all remaining subsamples, balanced the same way.
       training_genes= data[data[split]!=test_subsample].index
       training_genes= np.array(list(set(training_genes).intersection(set(geneIDs))))
       # In training set, downsample expressed genes to make them balanced with unexpressed genes
       unexpressed_training_genes=[gene for gene in training_genes if data['category'][gene]=='unexpressed'  ]
       expressed_training_genes=[gene for gene in training_genes if data['category'][gene]=='expressed']
       expressed_training_genes=np.random.choice(expressed_training_genes,len(unexpressed_training_genes),replace=False)
       training_genes=np.concatenate((expressed_training_genes,unexpressed_training_genes),axis=0)

       # Map gene IDs back to row indices in the one-hot arrays; shuffle the
       # training order so expressed/unexpressed examples are interleaved.
       training_indices=np.array([np.where(geneIDs==element)[0][0] for element in training_genes])
       testing_indices=np.array([np.where(geneIDs==element)[0][0] for element in testing_genes])
       np.random.shuffle(training_indices)
       # 9 runs per fold: {real, dinucleotide-shuffled, single-nucleotide-shuffled}
       # training sequences x {Pro+Ter, Pro, Ter}; testing always uses real sequences.
       train_models(one_hot_seqs,               one_hot_seqs,      training_indices,testing_indices,split,test_subsample,'None','Pro_and_Ter')
       train_models(one_hot_seqs_pro,           one_hot_seqs_pro,  training_indices,testing_indices,split,test_subsample,'None','Pro')
       train_models(one_hot_seqs_ter,           one_hot_seqs_ter,  training_indices,testing_indices,split,test_subsample,'None','Ter')
       train_models(one_hot_seqs_dshuffle,      one_hot_seqs,      training_indices,testing_indices,split,test_subsample,'D_Shuffle','Pro_and_Ter')
       train_models(one_hot_seqs_pro_dshuffle,  one_hot_seqs_pro,  training_indices,testing_indices,split,test_subsample,'D_Shuffle','Pro')
       train_models(one_hot_seqs_ter_dshuffle,  one_hot_seqs_ter,  training_indices,testing_indices,split,test_subsample,'D_Shuffle','Ter')
       train_models(one_hot_seqs_sshuffle,      one_hot_seqs,      training_indices,testing_indices,split,test_subsample,'S_Shuffle','Pro_and_Ter')
       train_models(one_hot_seqs_pro_sshuffle,  one_hot_seqs_pro,  training_indices,testing_indices,split,test_subsample,'S_Shuffle','Pro')
       train_models(one_hot_seqs_ter_sshuffle,  one_hot_seqs_ter,  training_indices,testing_indices,split,test_subsample,'S_Shuffle','Ter')
